{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.212768Z",
     "start_time": "2020-11-17T13:24:05.510414Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=2)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "from collections import Counter\n",
    "\n",
    "sns.set_style('ticks')\n",
    "\n",
    "%matplotlib inline\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['figure.dpi']= 300\n",
    "mpl.rc(\"savefig\", dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Read files and select drugs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.217543Z",
     "start_time": "2020-11-17T13:24:06.214411Z"
    }
   },
   "outputs": [],
   "source": [
    "ref_type = 'log2_median_ic50_hn' # log2_median_ic50_3f_hn | log2_median_ic50_hn\n",
    "model_name = 'RWEN'\n",
    "\n",
    "# shift the dosage as GDSC experiment (Syto60) is less sensitive\n",
    "dosage_shifted = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.230761Z",
     "start_time": "2020-11-17T13:24:06.220771Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(81, 27)\n"
     ]
    }
   ],
   "source": [
    "# Per-drug table; the first CSV column becomes the index (drug ids).\n",
    "drug_info_df = pd.read_csv('../preprocessed_data/GDSC/hn_drug_stat.csv', index_col=0)\n",
    "# Drug ids are handled as strings from here on.\n",
    "drug_info_df.index = drug_info_df.index.astype(str)\n",
    "\n",
    "# Lookup: drug id (str) -> value of the 'Drug Name' column.\n",
    "drug_id_name_dict = drug_info_df['Drug Name'].to_dict()\n",
    "print(drug_info_df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.239141Z",
     "start_time": "2020-11-17T13:24:06.232645Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tested_drug_list = [1032, 1007, 133, 201, 1010] + [182, 301, 302]\n",
    "\n",
    "# Convert the string drug-table index to integers once, rather than on every membership test.\n",
    "known_drug_ids = set(drug_info_df.index.astype(int))\n",
    "\n",
    "# Ids listed above that are absent from the drug table (the saved run shows an empty list).\n",
    "[d for d in tested_drug_list if d not in known_drug_ids]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Read predicted IC50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.243682Z",
     "start_time": "2020-11-17T13:24:06.241128Z"
    }
   },
   "outputs": [],
   "source": [
    "norm_type = 'cell_TPM'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.256162Z",
     "start_time": "2020-11-17T13:24:06.245419Z"
    }
   },
   "outputs": [],
   "source": [
    "cadrres_cell_df = pd.read_csv('../result/HN_model/{}/{}_pred.csv'.format(norm_type, model_name), index_col=0)\n",
    "out_dir = '../result/HN_model/{}/'.format(norm_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:06.271491Z",
     "start_time": "2020-11-17T13:24:06.257886Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1007</th>\n",
       "      <th>133</th>\n",
       "      <th>201</th>\n",
       "      <th>1010</th>\n",
       "      <th>182</th>\n",
       "      <th>301</th>\n",
       "      <th>302</th>\n",
       "      <th>1012</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>RHH2176</th>\n",
       "      <td>-6.395335</td>\n",
       "      <td>-6.514972</td>\n",
       "      <td>-12.464670</td>\n",
       "      <td>-6.430358</td>\n",
       "      <td>-3.722897</td>\n",
       "      <td>0.708929</td>\n",
       "      <td>-0.279589</td>\n",
       "      <td>-0.499479</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2177</th>\n",
       "      <td>-10.289315</td>\n",
       "      <td>-6.113236</td>\n",
       "      <td>-9.612803</td>\n",
       "      <td>0.756091</td>\n",
       "      <td>-4.753741</td>\n",
       "      <td>-3.594752</td>\n",
       "      <td>-0.380591</td>\n",
       "      <td>-0.011896</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2178</th>\n",
       "      <td>-10.098690</td>\n",
       "      <td>-5.891169</td>\n",
       "      <td>-12.298502</td>\n",
       "      <td>1.642715</td>\n",
       "      <td>-3.403522</td>\n",
       "      <td>1.604877</td>\n",
       "      <td>0.143180</td>\n",
       "      <td>-1.247750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2179</th>\n",
       "      <td>-9.486539</td>\n",
       "      <td>-5.705740</td>\n",
       "      <td>-12.306209</td>\n",
       "      <td>1.522135</td>\n",
       "      <td>-5.242828</td>\n",
       "      <td>-1.395477</td>\n",
       "      <td>1.562695</td>\n",
       "      <td>-2.288641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2180</th>\n",
       "      <td>-11.620982</td>\n",
       "      <td>-6.841660</td>\n",
       "      <td>-7.702711</td>\n",
       "      <td>3.060791</td>\n",
       "      <td>-4.710782</td>\n",
       "      <td>-1.831584</td>\n",
       "      <td>-5.050828</td>\n",
       "      <td>-1.279218</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              1007       133        201      1010       182       301  \\\n",
       "RHH2176  -6.395335 -6.514972 -12.464670 -6.430358 -3.722897  0.708929   \n",
       "RHH2177 -10.289315 -6.113236  -9.612803  0.756091 -4.753741 -3.594752   \n",
       "RHH2178 -10.098690 -5.891169 -12.298502  1.642715 -3.403522  1.604877   \n",
       "RHH2179  -9.486539 -5.705740 -12.306209  1.522135 -5.242828 -1.395477   \n",
       "RHH2180 -11.620982 -6.841660  -7.702711  3.060791 -4.710782 -1.831584   \n",
       "\n",
       "              302      1012  \n",
       "RHH2176 -0.279589 -0.499479  \n",
       "RHH2177 -0.380591 -0.011896  \n",
       "RHH2178  0.143180 -1.247750  \n",
       "RHH2179  1.562695 -2.288641  \n",
       "RHH2180 -5.050828 -1.279218  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cadrres_cell_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:44.727919Z",
     "start_time": "2020-11-17T13:24:44.722900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8 1116\n"
     ]
    }
   ],
   "source": [
    "# Columns of the prediction table are drug ids; rows are the entries kept in cluster_list.\n",
    "drug_list, cluster_list = cadrres_cell_df.columns, cadrres_cell_df.index\n",
    "print(len(drug_list), len(cluster_list))\n",
    "\n",
    "# Reorder the drug table to the prediction columns so that later positional\n",
    "# arithmetic (via .values) lines up drug by drug.\n",
    "drug_info_df = drug_info_df.loc[drug_list]\n",
    "cadrres_cell_df = cadrres_cell_df[drug_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:46.258031Z",
     "start_time": "2020-11-17T13:24:46.255722Z"
    }
   },
   "outputs": [],
   "source": [
    "if dosage_shifted:\n",
    "    # The predictions are used as log2 values downstream (compared with the log2 reference\n",
    "    # and passed through 2**x), so subtracting 2 lowers each value 4-fold.\n",
    "    # The original note read 'Shift by 4 uM'.\n",
    "    # Fixed: this previously referenced the undefined name cadrres_cluster_df, which raised\n",
    "    # a NameError whenever dosage_shifted was True. The frame loaded above is cadrres_cell_df.\n",
    "    # Re-running this cell shifts the values again, so run it once per kernel session.\n",
    "    cadrres_cell_df = cadrres_cell_df - 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Read cell metadata (patient of origin for each cell; the cluster % table is no longer loaded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:54.839582Z",
     "start_time": "2020-11-17T13:24:54.837443Z"
    }
   },
   "outputs": [],
   "source": [
    "# freq_df = pd.read_excel('../preprocessed_data/HN_patient_specific/percent_patient_tpm_cluster.xlsx', index_col=[0, 1]).reset_index()\n",
    "# freq_df = freq_df.pivot(index='patient_id', columns='cluster', values='percent').fillna(0) / 100\n",
    "\n",
    "# patient_list = freq_df.index\n",
    "\n",
    "# freq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:55.174156Z",
     "start_time": "2020-11-17T13:24:55.161222Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient_id</th>\n",
       "      <th>cell_line_id</th>\n",
       "      <th>origin</th>\n",
       "      <th>batch</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>RHH2176</th>\n",
       "      <td>HN120</td>\n",
       "      <td>HN120P</td>\n",
       "      <td>Primary</td>\n",
       "      <td>RHH</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2177</th>\n",
       "      <td>HN120</td>\n",
       "      <td>HN120P</td>\n",
       "      <td>Primary</td>\n",
       "      <td>RHH</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2178</th>\n",
       "      <td>HN120</td>\n",
       "      <td>HN120P</td>\n",
       "      <td>Primary</td>\n",
       "      <td>RHH</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2179</th>\n",
       "      <td>HN120</td>\n",
       "      <td>HN120P</td>\n",
       "      <td>Primary</td>\n",
       "      <td>RHH</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2180</th>\n",
       "      <td>HN120</td>\n",
       "      <td>HN120P</td>\n",
       "      <td>Primary</td>\n",
       "      <td>RHH</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        patient_id cell_line_id   origin batch\n",
       "RHH2176      HN120       HN120P  Primary   RHH\n",
       "RHH2177      HN120       HN120P  Primary   RHH\n",
       "RHH2178      HN120       HN120P  Primary   RHH\n",
       "RHH2179      HN120       HN120P  Primary   RHH\n",
       "RHH2180      HN120       HN120P  Primary   RHH"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cell_info_df = pd.read_csv('../preprocessed_data/HN_patient_specific/cell_info.csv', index_col=0)\n",
    "cell_info_df = cell_info_df.loc[cadrres_cell_df.index]\n",
    "cell_info_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:55.500898Z",
     "start_time": "2020-11-17T13:24:55.496775Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['HN120', 'HN137', 'HN148', 'HN159', 'HN160', 'HN182'], dtype=object)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "patient_list = cell_info_df['patient_id'].unique()\n",
    "patient_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### List all pairs of patient and drug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:55.723603Z",
     "start_time": "2020-11-17T13:24:55.714007Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# Per-drug reference values from the column selected by ref_type. drug_info_df was\n",
    "# re-indexed by drug_list in an earlier cell, so these line up positionally with the\n",
    "# prediction columns.\n",
    "ref_values = drug_info_df[ref_type].values\n",
    "\n",
    "# delta = prediction minus the per-drug reference.\n",
    "pred_delta_df = pd.DataFrame(\n",
    "    cadrres_cell_df.values - ref_values,\n",
    "    index=cluster_list,\n",
    "    columns=drug_list,\n",
    ")\n",
    "\n",
    "# Logistic transform 100 / (1 + 2**(-delta)) puts pred_cv_df on a 0-100 scale\n",
    "# (presumably percent cell viability -- confirm); pred_kill_df is its complement.\n",
    "pow2_neg_delta = np.power(2, -pred_delta_df)\n",
    "pred_cv_df = 100 / (1 + pow2_neg_delta)\n",
    "pred_kill_df = 100 - pred_cv_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:56.281976Z",
     "start_time": "2020-11-17T13:24:56.265046Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1007</th>\n",
       "      <th>133</th>\n",
       "      <th>201</th>\n",
       "      <th>1010</th>\n",
       "      <th>182</th>\n",
       "      <th>301</th>\n",
       "      <th>302</th>\n",
       "      <th>1012</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>RHH2176</th>\n",
       "      <td>8.666278</td>\n",
       "      <td>92.867722</td>\n",
       "      <td>93.898495</td>\n",
       "      <td>97.080501</td>\n",
       "      <td>62.941906</td>\n",
       "      <td>88.168353</td>\n",
       "      <td>90.989682</td>\n",
       "      <td>69.828169</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2177</th>\n",
       "      <td>58.516692</td>\n",
       "      <td>90.788528</td>\n",
       "      <td>68.068353</td>\n",
       "      <td>18.586010</td>\n",
       "      <td>77.630330</td>\n",
       "      <td>99.325077</td>\n",
       "      <td>91.547421</td>\n",
       "      <td>62.273231</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2178</th>\n",
       "      <td>55.277508</td>\n",
       "      <td>89.417912</td>\n",
       "      <td>93.204274</td>\n",
       "      <td>10.990625</td>\n",
       "      <td>57.648288</td>\n",
       "      <td>80.018420</td>\n",
       "      <td>88.281220</td>\n",
       "      <td>79.540035</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2179</th>\n",
       "      <td>44.709356</td>\n",
       "      <td>88.138675</td>\n",
       "      <td>93.238032</td>\n",
       "      <td>11.835311</td>\n",
       "      <td>82.966591</td>\n",
       "      <td>96.973801</td>\n",
       "      <td>73.796178</td>\n",
       "      <td>88.887303</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHH2180</th>\n",
       "      <td>78.023937</td>\n",
       "      <td>94.229574</td>\n",
       "      <td>36.191544</td>\n",
       "      <td>4.416577</td>\n",
       "      <td>77.108980</td>\n",
       "      <td>97.745457</td>\n",
       "      <td>99.638682</td>\n",
       "      <td>79.892713</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHO713</th>\n",
       "      <td>37.870813</td>\n",
       "      <td>91.607442</td>\n",
       "      <td>17.584649</td>\n",
       "      <td>53.926024</td>\n",
       "      <td>51.019792</td>\n",
       "      <td>56.454351</td>\n",
       "      <td>98.595270</td>\n",
       "      <td>68.886323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHO714</th>\n",
       "      <td>48.434243</td>\n",
       "      <td>82.894292</td>\n",
       "      <td>92.994384</td>\n",
       "      <td>52.192966</td>\n",
       "      <td>26.767225</td>\n",
       "      <td>92.893844</td>\n",
       "      <td>40.117381</td>\n",
       "      <td>81.908822</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHO715</th>\n",
       "      <td>77.328245</td>\n",
       "      <td>88.541155</td>\n",
       "      <td>61.259013</td>\n",
       "      <td>7.754993</td>\n",
       "      <td>65.461802</td>\n",
       "      <td>99.658395</td>\n",
       "      <td>95.397265</td>\n",
       "      <td>83.287039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHO716</th>\n",
       "      <td>50.734027</td>\n",
       "      <td>84.971952</td>\n",
       "      <td>80.151082</td>\n",
       "      <td>33.738551</td>\n",
       "      <td>30.443759</td>\n",
       "      <td>81.050057</td>\n",
       "      <td>62.068290</td>\n",
       "      <td>83.559756</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RHO717</th>\n",
       "      <td>64.689195</td>\n",
       "      <td>86.000888</td>\n",
       "      <td>76.139982</td>\n",
       "      <td>56.946455</td>\n",
       "      <td>83.028406</td>\n",
       "      <td>85.767991</td>\n",
       "      <td>98.701439</td>\n",
       "      <td>82.530086</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1116 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "              1007        133        201       1010        182        301  \\\n",
       "RHH2176   8.666278  92.867722  93.898495  97.080501  62.941906  88.168353   \n",
       "RHH2177  58.516692  90.788528  68.068353  18.586010  77.630330  99.325077   \n",
       "RHH2178  55.277508  89.417912  93.204274  10.990625  57.648288  80.018420   \n",
       "RHH2179  44.709356  88.138675  93.238032  11.835311  82.966591  96.973801   \n",
       "RHH2180  78.023937  94.229574  36.191544   4.416577  77.108980  97.745457   \n",
       "...            ...        ...        ...        ...        ...        ...   \n",
       "RHO713   37.870813  91.607442  17.584649  53.926024  51.019792  56.454351   \n",
       "RHO714   48.434243  82.894292  92.994384  52.192966  26.767225  92.893844   \n",
       "RHO715   77.328245  88.541155  61.259013   7.754993  65.461802  99.658395   \n",
       "RHO716   50.734027  84.971952  80.151082  33.738551  30.443759  81.050057   \n",
       "RHO717   64.689195  86.000888  76.139982  56.946455  83.028406  85.767991   \n",
       "\n",
       "               302       1012  \n",
       "RHH2176  90.989682  69.828169  \n",
       "RHH2177  91.547421  62.273231  \n",
       "RHH2178  88.281220  79.540035  \n",
       "RHH2179  73.796178  88.887303  \n",
       "RHH2180  99.638682  79.892713  \n",
       "...            ...        ...  \n",
       "RHO713   98.595270  68.886323  \n",
       "RHO714   40.117381  81.908822  \n",
       "RHO715   95.397265  83.287039  \n",
       "RHO716   62.068290  83.559756  \n",
       "RHO717   98.701439  82.530086  \n",
       "\n",
       "[1116 rows x 8 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_kill_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:56.586328Z",
     "start_time": "2020-11-17T13:24:56.563029Z"
    }
   },
   "outputs": [],
   "source": [
    "# Per-cell kill predictions. Tag the file name when the dosage shift is enabled, as the\n",
    "# per-patient save at the end of the notebook does, so a shifted run does not overwrite\n",
    "# the unshifted per-cell output. The name is unchanged when dosage_shifted is False.\n",
    "cell_file_suffix = '_shifted' if dosage_shifted else ''\n",
    "pred_kill_df.to_csv(out_dir + 'pred_drug_kill_{}_{}_cell{}.csv'.format(ref_type, model_name, cell_file_suffix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:56.907841Z",
     "start_time": "2020-11-17T13:24:56.890623Z"
    }
   },
   "outputs": [],
   "source": [
    "def _patient_mean_long(cell_df, value_name):\n",
    "    \"\"\"Average per-row values within each patient and reshape to long format.\n",
    "\n",
    "    cell_df: frame sharing its index with cell_info_df, one column per drug.\n",
    "    value_name: name given to the value column of the result.\n",
    "    Returns a frame with columns ['patient', 'drug_id', value_name].\n",
    "    \"\"\"\n",
    "    with_patient = pd.merge(cell_df, cell_info_df[['patient_id']], left_index=True, right_index=True)\n",
    "    per_patient = with_patient.groupby('patient_id').mean()\n",
    "    long_df = per_patient.stack().reset_index()\n",
    "    long_df.columns = ['patient', 'drug_id', value_name]\n",
    "    return long_df\n",
    "\n",
    "\n",
    "# The 'pateint' spelling is kept because a later cell references these names.\n",
    "pred_pateint_delta_df = _patient_mean_long(pred_delta_df, 'delta')\n",
    "pred_pateint_kill_df = _patient_mean_long(pred_kill_df, 'kill')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:57.088050Z",
     "start_time": "2020-11-17T13:24:57.074080Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient</th>\n",
       "      <th>drug_id</th>\n",
       "      <th>delta</th>\n",
       "      <th>kill</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>HN120</td>\n",
       "      <td>1007</td>\n",
       "      <td>-0.462341</td>\n",
       "      <td>56.818499</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>HN120</td>\n",
       "      <td>133</td>\n",
       "      <td>-3.107634</td>\n",
       "      <td>87.297331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>HN120</td>\n",
       "      <td>201</td>\n",
       "      <td>-2.125949</td>\n",
       "      <td>76.854282</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>HN120</td>\n",
       "      <td>1010</td>\n",
       "      <td>1.480563</td>\n",
       "      <td>33.410070</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>HN120</td>\n",
       "      <td>182</td>\n",
       "      <td>-1.967096</td>\n",
       "      <td>75.642212</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  patient drug_id     delta       kill\n",
       "0   HN120    1007 -0.462341  56.818499\n",
       "1   HN120     133 -3.107634  87.297331\n",
       "2   HN120     201 -2.125949  76.854282\n",
       "3   HN120    1010  1.480563  33.410070\n",
       "4   HN120     182 -1.967096  75.642212"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Join the per-patient delta and kill tables on their shared (patient, drug_id) key.\n",
    "single_drug_pred_df = pred_pateint_delta_df.merge(pred_pateint_kill_df, on=['patient', 'drug_id'])\n",
    "single_drug_pred_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:57.258797Z",
     "start_time": "2020-11-17T13:24:57.252271Z"
    }
   },
   "outputs": [],
   "source": [
    "# Attach the drug name for every row; an id missing from the lookup raises KeyError.\n",
    "single_drug_pred_df['drug_name'] = [drug_id_name_dict[drug_id] for drug_id in single_drug_pred_df['drug_id']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-11-17T13:24:57.620996Z",
     "start_time": "2020-11-17T13:24:57.610876Z"
    }
   },
   "outputs": [],
   "source": [
    "# The same table is written either way; only the file name records whether the dosage shift was applied.\n",
    "out_name_template = 'pred_drug_kill_{}_{}_shifted.csv' if dosage_shifted else 'pred_drug_kill_{}_{}.csv'\n",
    "single_drug_pred_df.to_csv(out_dir + out_name_template.format(ref_type, model_name), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "175.531px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
